package nsalt.func

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.arch._
import nsalt.util._


/** Port bundle of the arithmetic/logic unit.
  *
  * Extends [[FuncUnitPort]]; the `in`/`out` channels that [[ArithLogic]] uses
  * (`in.valid`, `in.bits.src1/src2/oper`, `in.ready`, `out.bits`, `out.valid`,
  * `out.ready`) are presumably declared there -- they are not defined in this
  * file, so confirm against `FuncUnitPort`.
  */
class ArithLogicPort extends FuncUnitPort with Config {
  // XLEN-wide value added to the current pc to form the destination of a
  // conditional branch (see `branchDest` in ArithLogic).
  val offset    = Input(UInt(XLEN.W))
  // Control-flow information of the instruction being executed. Flipped, so
  // the bundle's directions are reversed here; ArithLogic only reads
  // `pc`, `instr`, `branchIndex` and `pcPred` from it.
  val ctrlFlow  = Flipped(new CtrlFlowPort())
  // Redirect request; ArithLogic drives `dest`, `valid` and `wholeFlushed`.
  val redirect  = new RedirectPort()
}

// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/backend/fu/ALU.scala
/** Combinational arithmetic/logic unit that also resolves branches and jumps.
  *
  * For ordinary operations the unit computes add/sub, shifts, comparisons and
  * bitwise logic on `src1`/`src2` and presents the result on `io.alu.out.bits`.
  * For branch-unit operations it additionally
  *   - decides whether a conditional branch is taken,
  *   - computes the branch/jump destination,
  *   - compares the outcome against the prediction carried in `ctrlFlow` and
  *     raises `io.alu.redirect` on a misprediction,
  *   - reports the outcome on `io.feed`,
  *   - and outputs `pc + 4` / `pc + 2` instead of the ALU result.
  *
  * No state is held in this module: `out.valid` follows `in.valid` and
  * `in.ready` follows `out.ready` in the same cycle.
  */
class ArithLogic() extends Module with Config {

  val io = IO(new Bundle{
    val alu  = new ArithLogicPort()
    val feed = new FeedbackPort()
  })

  val valid = io.alu.in.valid
  val src1  = io.alu.in.bits.src1
  val src2  = io.alu.in.bits.src2
  val oper  = io.alu.in.bits.oper

  /** Convenience hook-up: connects the given signals to this unit's input
    * channel and returns the unit's result wire (`io.alu.out.bits`).
    *
    * @param valid input-valid signal to connect
    * @param src1  first source operand
    * @param src2  second source operand
    * @param oper  operation selector (an `ArithLogicOperType` encoding)
    * @return the unit's output data signal
    */
  def access(valid: Bool, src1: UInt, src2: UInt, oper: UInt): UInt = {
    this.valid := valid
    this.src1 := src1
    this.src2 := src2
    this.oper := oper
    io.alu.out.bits
  }

  // ADD / SUB
  // Every operation that is not an add runs the adder as a subtractor
  // (src1 + ~src2 + 1), so the comparison results below are always available.
  // `+&` keeps the carry-out, making addRes XLEN+1 bits wide.
  val isSub = !ArithLogicOperType.isAdd(oper)
  val addRes = (src1 +& (src2 ^ Fill(XLEN, isSub))) + isSub

  // XOR
  val xorRes = src1 ^ src2

  // SLTU (Set on Less Than unsigned)
  // When subtracting, the carry-out (bit XLEN) is set iff src1 >= src2
  // unsigned, so its negation is "src1 < src2".
  val sltu = !addRes(XLEN)

  // SLT (Set on Less Than)
  // If the operands' sign bits differ, the signed ordering is the opposite of
  // the unsigned one; otherwise both agree.
  val slt = xorRes(XLEN-1) ^ sltu

  // Shifted Source Operand 1
  // Word right-shifts operate on the low 32 bits only, zero-extended for the
  // logical shift and sign-extended for the arithmetic shift.
  val shsrc1 = LookupTreeDefault(oper, src1(XLEN-1,0), List(
    ArithLogicOperType.srlw -> ZeroExt(src1(31,0), XLEN),
    ArithLogicOperType.sraw -> SignExt(src1(31,0), XLEN)
  ))

  // shamt: conventional abbrev for "Shift Amount"
  // 5 bits for word operations (and for XLEN == 32), 6 bits otherwise.
  val shamt = Mux(
    ArithLogicOperType.isWordOper(oper),
    src2(4, 0),
    if (XLEN == 64) src2(5, 0) else src2(4, 0)
  )

  // Arith & Logic result, selected by the low four bits of the operation
  // code; anything not listed falls back to the adder output.
  val res = LookupTreeDefault(oper(3, 0), addRes, List(
    ArithLogicOperType.sll  -> ((shsrc1  << shamt)(XLEN-1, 0)),
    ArithLogicOperType.slt  -> ZeroExt(slt, XLEN),
    ArithLogicOperType.sltu -> ZeroExt(sltu, XLEN),
    ArithLogicOperType.xor  -> xorRes,
    ArithLogicOperType.srl  -> (shsrc1  >> shamt),
    ArithLogicOperType.or   -> (src1  |  src2),
    ArithLogicOperType.and  -> (src1  &  src2),
    ArithLogicOperType.sra  -> ((shsrc1.asSInt >> shamt).asUInt))
  )

  // Word operations keep only the low 32 bits of the result, sign-extended
  // to the full register width. The target width is XLEN (previously a
  // hard-coded 64) so the data path stays consistent for any XLEN.
  // NOTE(review): assumes SignExt tolerates a target width equal to the
  // input width (the XLEN == 32 case) -- confirm against nsalt.util.SignExt.
  val result = Mux(
    ArithLogicOperType.isWordOper(oper),
    SignExt(res(31,0), XLEN),
    res
  )

  // isCond:   conditional branch; isBranch: any branch-unit operation.
  val isCond   = ArithLogicOperType.isCond(oper)
  val isBranch = ArithLogicOperType.isBranch(oper)

  // The comparison is chosen by the branch type; the negated variants reuse
  // the same comparison and flip it via isCondNeg.
  val branchTaken = LookupTree(
    ArithLogicOperType.getBranchType(oper),
    List(
      ArithLogicOperType.getBranchType(ArithLogicOperType.beq)  -> !xorRes.orR,
      ArithLogicOperType.getBranchType(ArithLogicOperType.blt)  -> slt,
      ArithLogicOperType.getBranchType(ArithLogicOperType.bltu) -> sltu
    )
  ) ^ ArithLogicOperType.isCondNeg(oper)

  // Calculate branch destination, and correct prediction if missed.
  // Conditional branches use pc + offset; all other operations take the
  // adder output. The value is truncated to the virtual-address width.
  val branchDest = Mux(
    isCond,
    io.alu.ctrlFlow.pc + io.alu.offset,
    addRes
  )(VIRT_MEM_ADDR_LEN - 1, 0)

  // Misprediction check.
  //  - not-taken conditional branch: wrong iff branchIndex(0) is set
  //    (presumably "predicted taken" -- confirm against the predictor);
  //  - otherwise: wrong iff branchIndex(0) is clear or the predicted pc
  //    differs from the actual destination.
  val branchIndex = io.alu.ctrlFlow.branchIndex(0)
  val isMispred = Mux(
    !branchTaken && isCond,
    branchIndex,
    !branchIndex || (io.alu.redirect.dest =/= io.alu.ctrlFlow.pcPred)
  )

  // An instruction whose two low bits are not 0b11 is treated as compressed
  // (2 bytes long instead of 4).
  val compressed = (io.alu.ctrlFlow.instr(1,0) =/= "b11".U)

  // Actual next pc: the fall-through address for a not-taken conditional
  // branch, the computed destination otherwise.
  io.alu.redirect.dest := Mux(
    !branchTaken && isCond,
    Mux(
      compressed,
      io.alu.ctrlFlow.pc + 2.U,
      io.alu.ctrlFlow.pc + 4.U
    ),
    branchDest
  )

  // Only redirect for valid branch-unit operations that were mispredicted.
  io.alu.redirect.valid := valid && isBranch && isMispred

  // val redirectRtype = if (EnableOutOfOrderExec) 1.U else 0.U
  // io.redirect.rtype := redirectRtype
  io.alu.redirect.wholeFlushed := 0.U

  // mark redirect type as speculative exec fix
  // may be can be moved to ISU to calculate pc + 4
  // this is actually for jal and jalr to write pc + 4/2 to rd
  io.alu.out.bits := Mux(
    isBranch,
    Mux(
      !compressed,
      SignExt(io.alu.ctrlFlow.pc, ADDR_BITS) + 4.U,
      SignExt(io.alu.ctrlFlow.pc, ADDR_BITS) + 2.U),
    result
  )

  // Pass-through handshake: no buffering inside the unit.
  io.alu.in.ready  := io.alu.out.ready
  io.alu.out.valid := valid

  // Branch outcome feedback, valid for every executed branch-unit operation.
  io.feed.valid        := valid && isBranch
  io.feed.pc           := io.alu.ctrlFlow.pc
  io.feed.isMispred    := isMispred
  io.feed.branchDest   := branchDest
  io.feed.branchTaken  := branchTaken
  io.feed.operType     := oper
  io.feed.branchType   := LookupTree(oper, RV32I_BRUInstr.branchTypeList)
  io.feed.compressed   := compressed

}
